using BenchmarkTools, OrdinaryDiffEq, Sundials, SparseArrays, LinearAlgebra

const cs = (
    const_m5042476316481538586 = [[0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.        ]
    [0.03237701]],
    const_m1664908561639471167 = [[-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]],
    const_m1742632654512844898 = [[ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [ 0.        ]
    [-0.14182856]],
    const_6902631343103379854 = sparse([ 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9,10,10,11,11,12,12,
    13,13,14,14,15,15,16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,
    25,25,26,26,27,27,28,28,29,29], [ 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9,10,10,11,11,12,12,13,
    13,14,14,15,15,16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,
    25,26,26,27,27,28,28,29,29,30], [-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,
    -30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,
    -30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,
    -30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,
    -30., 30.], 29, 30),
    const_m2887110228888269587 = sparse([ 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,
    26,27,28,29,30], [ 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,
    25,26,27,28,29], [1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,
    1.,1.,1.,1.,1.], 31, 29),
    const_948199332974845790 = [[-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]
    [-1.]],
    const_m3416783780236702828 = [[0.        ]
    [0.00111111]
    [0.00444444]
    [0.01      ]
    [0.01777778]
    [0.02777778]
    [0.04      ]
    [0.05444444]
    [0.07111111]
    [0.09      ]
    [0.11111111]
    [0.13444444]
    [0.16      ]
    [0.18777778]
    [0.21777778]
    [0.25      ]
    [0.28444444]
    [0.32111111]
    [0.36      ]
    [0.40111111]
    [0.44444444]
    [0.49      ]
    [0.53777778]
    [0.58777778]
    [0.64      ]
    [0.69444444]
    [0.75111111]
    [0.81      ]
    [0.87111111]
    [0.93444444]
    [1.        ]],
    const_m4709902773269181539 = sparse([ 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9,10,10,11,11,12,12,
    13,13,14,14,15,15,16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,
    25,25,26,26,27,27,28,28,29,29,30,30], [ 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9,10,10,11,11,12,12,13,
    13,14,14,15,15,16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,
    25,26,26,27,27,28,28,29,29,30,30,31], [-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,
    -30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,
    -30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,
    -30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,-30., 30.,
    -30., 30.,-30., 30.], 30, 31),
    const_m6560247611775222723 = [[3.60000000e+03]
    [4.00000000e+02]
    [1.44000000e+02]
    [7.34693878e+01]
    [4.44444444e+01]
    [2.97520661e+01]
    [2.13017751e+01]
    [1.60000000e+01]
    [1.24567474e+01]
    [9.97229917e+00]
    [8.16326531e+00]
    [6.80529301e+00]
    [5.76000000e+00]
    [4.93827160e+00]
    [4.28061831e+00]
    [3.74609781e+00]
    [3.30578512e+00]
    [2.93877551e+00]
    [2.62965668e+00]
    [2.36686391e+00]
    [2.14158239e+00]
    [1.94699838e+00]
    [1.77777778e+00]
    [1.62969670e+00]
    [1.49937526e+00]
    [1.38408304e+00]
    [1.28159487e+00]
    [1.19008264e+00]
    [1.10803324e+00]
    [1.03418558e+00]],
    const_7523862830560677743 = 4.272493084154669,
    cache = zeros(29),
    cache2 = zeros(31),
    cache3 = zeros(30),)

# Kernel shared by both state segments of `f_pybamm`: pushes `y[rows]` through the
# operator chain stored in `c` and writes
# `coeff .* c.const_m6560247611775222723 .* c.cache3` into `dy[rows]`.
# Overwrites the scratch buffers `c.cache`, `c.cache2` and `c.cache3`.
function _pybamm_segment_rhs!(dy, y, c, rows, coeff)
    inner = c.cache
    edges = c.cache2
    outer = c.cache3
    mul!(inner, c.const_6902631343103379854, view(y, rows))
    mul!(edges, c.const_m2887110228888269587, inner)
    # Elementwise scaling followed by an additive vector, done in place on `edges`.
    edges .= c.const_m3416783780236702828 .* c.const_948199332974845790 .* edges .+
             c.const_m1742632654512844898
    mul!(outer, c.const_m4709902773269181539, edges)
    dy[rows] .= coeff .* c.const_m6560247611775222723 .* outer
end

"""
    f_pybamm(dy, y, p, t, c)

In-place right-hand side: fill `dy` from the state `y` using the constants and scratch
buffers stored in `c`.

`dy[1]` is set to the scalar `c.const_7523862830560677743`. The slices `2:31` and
`32:61` of `y` are then processed one after the other by the same chain of sparse
products and elementwise operations, differing only in the final scalar factor
(`-8.813457647415214` and `-22.598609352346706`), and the results are written to the
matching slices of `dy`.

# Arguments
- `dy`: output vector, written at indices `1:61`.
- `y`: state vector, read at indices `2:61`.
- `p`, `t`: accepted but not used.
- `c`: container with the `const_*` fields and the work vectors `cache`, `cache2` and
  `cache3`; the three work vectors are overwritten on every call.

# Returns
The value of the final broadcast assignment, i.e. a view of `dy[32:61]`.

NOTE(review): presumably the two slices are the two particle domains of a PyBaMM
model — confirm. Both slices use the same additive vector
`const_m1742632654512844898`; verify that this is intended.
"""
function f_pybamm(dy, y, p, t, c)
    dy[1] = c.const_7523862830560677743
    _pybamm_segment_rhs!(dy, y, c, 2:31, -8.813457647415214)
    return _pybamm_segment_rhs!(dy, y, c, 32:61, -22.598609352346706)
end
# Four-argument in-place right-hand side handed to `ODEProblem`; it forwards to
# `f_pybamm` with the constant table `cs` as the extra argument.
function f(dy, y, p, t)
    return f_pybamm(dy, y, p, t, cs)
end

# Initial state (61 entries): a leading 0.0, then 30 entries of 0.8000000000000016
# (indices 2:31) followed by 30 entries of 0.6 (indices 32:61).
u0 = vcat(0.0, fill(0.8000000000000016, 30), fill(0.6, 30))
tspan = (0.0, 0.15)
prob = ODEProblem(f, u0, tspan)

# Time the same problem with two stiff solvers: KenCarp47 from OrdinaryDiffEq and
# CVODE_BDF from Sundials, both at 1e-8 tolerances and saving 100 equally spaced steps.
#
# `prob` is a non-constant global, so it is interpolated with `$`: BenchmarkTools then
# passes it into the benchmark as a concretely typed value instead of timing an untyped
# global lookup and dynamic dispatch on every sample.
@btime solve($prob, KenCarp47(autodiff=false), reltol=1e-8, abstol=1e-8, saveat=0.15 / 100)

@btime solve($prob, CVODE_BDF(), reltol=1e-8, abstol=1e-8, saveat=0.15 / 100)

# Recorded timings
#
# Python (PyBaMM -> CasADi -> Sundials):
#   ~ 7.5 ms
#
# Julia, before optimization:
#   2.527 ms (13965 allocations: 1.62 MiB)
#
# Julia, after optimization, with Sundials CVODE_BDF (C library):
#   1.145 ms (3235 allocations: 188.14 KiB)
#
# Julia, after optimization, with OrdinaryDiffEq KenCarp47 (pure Julia):
#   743.099 μs (493 allocations: 155.91 KiB)

using Profile

# Collect profile samples over 1000 repeated TRBDF2 solves of `prob`, open the results
# in the Juno profiler view, then discard the collected samples.
# NOTE(review): `Juno` is not imported anywhere in this file; presumably this section is
# meant to be run inside the Juno/Atom IDE where that module is available — confirm.
@profile for _ in 1:1000
    solve(prob, TRBDF2(); reltol=1e-8, abstol=1e-8, saveat=0.15 / 100)
end
Juno.profiler()
Profile.clear()
